from __future__ import division
import numbers

import torchdiffeq
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init



class nconv(nn.Module):
    """Graph convolution primitive: mixes node features with an adjacency matrix."""

    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        """Aggregate neighbour features.

        x: (batch, dim, nodes, seq_len); A: (nodes, nodes).
        Returns a contiguous tensor of the same shape as x.
        """
        mixed = torch.einsum('ncwl,vw->ncvl', (x, A))
        return mixed.contiguous()


# class wconv(nn.Module):
#     def __init__(self):
#         super(wconv, self).__init__()

#     def forward(self, x, W):
#         # x.shape = (batch, dim, nodes, seq_len)
#         # w.shape = (dim, dim)
#         x = torch.einsum('ncwl,vc->nvwl', (x, W))
#         return x.contiguous()


# class dy_nconv(nn.Module):
#     def __init__(self):
#         super(dy_nconv,self).__init__()

#     def forward(self,x, A):
#         x = torch.einsum('ncvl,nvwl->ncwl', (x,A))
#         return x.contiguous()


class linear(nn.Module):
    """Channel-wise linear projection implemented as a 1x1 Conv2d."""

    def __init__(self, c_in, c_out, bias=True):
        super(linear, self).__init__()
        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0,0), stride=(1,1), bias=bias)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init weight matrices; uniform-init 1-D parameters (biases)."""
        for param in self.parameters():
            if param.dim() <= 1:
                nn.init.uniform_(param)
            else:
                nn.init.xavier_uniform_(param)

    def forward(self, x):
        # (batch, c_in, nodes, seq_len) -> (batch, c_out, nodes, seq_len)
        return self.mlp(x)


# class prop(nn.Module):
#     def __init__(self,c_in,c_out,gdep,dropout,alpha):
#         super(prop, self).__init__()
#         self.nconv = nconv()
#         self.mlp = linear(c_in,c_out)
#         self.gdep = gdep
#         self.dropout = dropout
#         self.alpha = alpha

#     def forward(self,x,adj):
#         adj = adj + torch.eye(adj.size(0)).to(x.device)
#         d = adj.sum(1)
#         h = x
#         dv = d
#         a = adj / dv.view(-1, 1)
#         for i in range(self.gdep):
#             h = self.alpha*x + (1-self.alpha)*self.nconv(h,a)
#         ho = self.mlp(h)
#         return ho


# class mixprop(nn.Module):
#     def __init__(self, c_in, c_out, gdep, dropout, alpha):
#         super(mixprop, self).__init__()
#         self.nconv = nconv()
#         self.mlp = linear((gdep + 1) * c_in, c_out)
#         self.gdep = gdep
#         self.dropout = dropout
#         self.alpha = alpha
#         self._reset_parameters()
#     def _reset_parameters(self):
#         for p in self.parameters():
#             if p.dim() > 1:
#                 nn.init.xavier_uniform_(p)
#             else:
#                 nn.init.uniform_(p)
#     def forward(self, x, adj):
#         adj = adj + torch.eye(adj.size(0)).to(x.device)
#         d = adj.sum(1)  # (nodes)
#         h = x
#         out = [h]
#         a = adj / d.view(-1, 1)  # d.view(-1, 1).shape = (nodes, 1)
#         for i in range(self.gdep):
#             h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
#             out.append(h)
#         ho = torch.cat(out, dim=1)
#         ho = self.mlp(ho)
#         return ho


# class dy_mixprop(nn.Module):
#     def __init__(self,c_in,c_out,gdep,dropout,alpha):
#         super(dy_mixprop, self).__init__()
#         self.nconv = dy_nconv()
#         self.mlp1 = linear((gdep+1)*c_in,c_out)
#         self.mlp2 = linear((gdep+1)*c_in,c_out)

#         self.gdep = gdep
#         self.dropout = dropout
#         self.alpha = alpha
#         self.lin1 = linear(c_in,c_in)
#         self.lin2 = linear(c_in,c_in)


#     def forward(self,x):
#         #adj = adj + torch.eye(adj.size(0)).to(x.device)
#         #d = adj.sum(1)
#         x1 = torch.tanh(self.lin1(x))
#         x2 = torch.tanh(self.lin2(x))
#         adj = self.nconv(x1.transpose(2,1),x2)
#         adj0 = torch.softmax(adj, dim=2)
#         adj1 = torch.softmax(adj.transpose(2,1), dim=2)

#         h = x
#         out = [h]
#         for i in range(self.gdep):
#             h = self.alpha*x + (1-self.alpha)*self.nconv(h,adj0)
#             out.append(h)
#         ho = torch.cat(out,dim=1)
#         ho1 = self.mlp1(ho)

#         h = x
#         out = [h]
#         for i in range(self.gdep):
#             h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj1)
#             out.append(h)
#         ho = torch.cat(out, dim=1)
#         ho2 = self.mlp2(ho)

#         return ho1+ho2


# class dilated_1D(nn.Module):
#     def __init__(self, cin, cout, dilation_factor=2):
#         super(dilated_1D, self).__init__()
#         self.tconv = nn.ModuleList()
#         self.kernel_set = [2,3,6,7]
#         self.tconv = nn.Conv2d(cin,cout,(1,7),dilation=(1,dilation_factor))

#     def forward(self,input):
#         x = self.tconv(input)
#         return x

class dilated_inception(nn.Module):
    """Inception-style temporal convolution.

    Runs parallel dilated convolutions with several kernel widths, truncates
    every branch to the length of the shortest (widest-kernel) output, and
    concatenates them along the channel axis.
    """

    def __init__(self, cin: int, cout: int, kernel_set: list, dilation_factor: int = 1):
        super(dilated_inception, self).__init__()
        self.tconv = nn.ModuleList()
        self.kernel_set = kernel_set
        # output channels are split evenly across the parallel branches
        self.cout = int(cout / len(self.kernel_set))
        for width in self.kernel_set:
            self.tconv.append(nn.Conv2d(cin, self.cout, (1, width), dilation=(1, dilation_factor)))
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init weight matrices; uniform-init 1-D parameters (biases)."""
        for param in self.parameters():
            if param.dim() <= 1:
                nn.init.uniform_(param)
            else:
                nn.init.xavier_uniform_(param)

    def forward(self, input):
        branches = [conv(input) for conv in self.tconv]
        # the widest kernel yields the shortest output; align all branches to it
        target_len = branches[-1].size(3)
        aligned = [b[..., -target_len:] for b in branches]
        return torch.cat(aligned, dim=1)

# class temporal_recur(nn.Module):
#     '''
#     A substitute of the dilated inception layer (Temporal Convolution)\n
#     Temporal Recurrent NN
#     '''
#     def __init__(self, c_in: int, c_out: int, 
#         name: str = 'GRU',
#         hidden_dim: int = 80,
#         num_layers: int = 1,
#         dropout: float = 0.0,
#     ):
#         super().__init__()

#         # Defining parameters
#         self.target_size = target_size = c_out
#         self.input_size = input_size = c_in
#         self.name = name
#         self.rnn = getattr(nn, name)(
#             input_size, hidden_dim, num_layers, batch_first=True, dropout=dropout
#         )

#         # The RNN module needs a linear layer V that transforms hidden states into outputs, individually
#         self.V = nn.Linear(hidden_dim, target_size)

#         self._reset_parameters()

#     def _reset_parameters(self):
#         for p in self.parameters():
#             if p.dim() > 1:
#                 nn.init.xavier_uniform_(p)
#             else:
#                 nn.init.uniform_(p)

#     def forward(self, X_in: torch.FloatTensor) -> torch.FloatTensor:
#         """
#         Making a forward pass of temporal RNN.

#         Arg types:
#             * **X_in** (Pytorch Float Tensor) - Input feature Tensor, with shape (batch_size, c_in, num_nodes, seq_len).

#         Return types:
#             * **X** (PyTorch Float Tensor) - Hidden representation for all nodes,
#             with shape (batch_size, c_out, num_nodes, seq_len). # -6
#         """
#         # data is of size (batch_size, input_length, input_size)
#         batch_size, c_in, num_nodes, seq_len = X_in.size()

#         X = X_in.permute(0,2,3,1).reshape(batch_size * num_nodes, seq_len, c_in)
#         out, last_hidden_state = self.rnn(X)
#         # out is of size (batch_size, seq_len, hidden_dim)
#         # Here, we apply the V matrix to every hidden state to produce the outputs
#         predictions = self.V(out)
#         # predictions is of size (batch_size, seq_len, target_size)
#         predictions = predictions.reshape(batch_size, num_nodes, seq_len, self.target_size).permute(0,3,1,2)

#         # returns outputs for all inputs, only the last one is needed for prediction time
#         return predictions

class graph_constructor(nn.Module):
    """Learns a sparse directed adjacency matrix over the graph nodes.

    Node embeddings (or provided static features) are projected, combined
    into an asymmetric score matrix, squashed with relu(tanh(.)), and
    sparsified to the top-k neighbours per row.
    """

    def __init__(self, nnodes: int, k: int, dim: int, alpha: float = 3, static_feat=None):
        super(graph_constructor, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            # project the given per-node feature matrix instead of learning embeddings
            feat_dim = static_feat.shape[1]
            self.lin1 = nn.Linear(feat_dim, dim)
            self.lin2 = nn.Linear(feat_dim, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)
            self.lin2 = nn.Linear(dim, dim)

        self.k = k
        self.dim = dim
        self.alpha = alpha
        self.static_feat = static_feat

    def _node_vectors(self, idx):
        """Return the two projected node-vector matrices for node ids `idx`."""
        if self.static_feat is None:
            v1 = self.emb1(idx)
            v2 = self.emb2(idx)
        else:
            v1 = self.static_feat[idx, :]
            v2 = v1
        v1 = torch.tanh(self.alpha * self.lin1(v1))
        v2 = torch.tanh(self.alpha * self.lin2(v2))
        return v1, v2

    def forward(self, idx):
        """Build the sparsified (top-k per row) adjacency for nodes `idx`."""
        v1, v2 = self._node_vectors(idx)
        score = torch.mm(v1, v2.transpose(1, 0)) - torch.mm(v2, v1.transpose(1, 0))
        adj = F.relu(torch.tanh(self.alpha * score))
        n = idx.size(0)
        mask = torch.zeros(n, n).to(adj.device)
        mask.fill_(float('0'))
        # tiny random jitter breaks ties before top-k selection ("bug fixed" upstream)
        vals, cols = (adj + torch.rand_like(adj) * 0.01).topk(self.k, 1)
        mask.scatter_(1, cols, vals.fill_(1))
        return adj * mask

    def fullA(self, idx):
        """Dense (un-sparsified) adjacency for nodes `idx`."""
        v1, v2 = self._node_vectors(idx)
        score = torch.mm(v1, v2.transpose(1, 0)) - torch.mm(v2, v1.transpose(1, 0))
        return F.relu(torch.tanh(self.alpha * score))



class CGPFunc(nn.Module):
    """ODE dynamics f(t, x) for continuous graph propagation.

    Evaluates dx/dt = 0.5 * alpha * (A_hat @ x + x) - x, where A_hat is the
    symmetrically normalized adjacency (self-loops added). Every evaluated
    state is also appended to `self.out` so the surrounding ODE block can
    mix intermediate representations.
    """

    def __init__(self, c_in, c_out, init_alpha):
        super(CGPFunc, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.x0 = None    # initial state, injected by the ODE block
        self.adj = None   # adjacency, injected by the ODE block before integration
        self.nfe = 0      # number of function evaluations
        self.alpha = init_alpha
        self.nconv = nconv()
        self.out = []     # intermediate states collected during integration

    def forward(self, t, x):
        # symmetric normalization: D^{-1/2} (A + I) D^{-1/2}
        adj_loop = self.adj + torch.eye(self.adj.size(0)).to(x.device)
        deg = adj_loop.sum(1)
        inv_sqrt = torch.diag(torch.pow(deg, -0.5))
        adj_norm = torch.mm(torch.mm(inv_sqrt, adj_loop), inv_sqrt)

        self.out.append(x)
        self.nfe += 1
        propagated = self.nconv(x, adj_norm)
        return 0.5 * self.alpha * (propagated + x) - x


class CGPODEBlock(nn.Module):
    """Integrates a CGPFunc over [0, t] and mixes all recorded states.

    The intermediate evaluations collected by the func (plus the final
    state) are concatenated along the channel axis and projected back to
    c_out via a 1x1 conv — a continuous analogue of a mix-hop readout.
    """

    def __init__(self, cgpfunc, method, solver_opt, rtol, atol, adjoint, estimated_nfe):
        super(CGPODEBlock, self).__init__()
        self.odefunc = cgpfunc
        self.method = method
        self.adjoint = adjoint
        self.solver_opt = solver_opt
        self.atol = atol
        self.rtol = rtol
        # one channel group per recorded state: estimated_nfe evaluations + final state
        self.mlp = linear((estimated_nfe + 1) * self.odefunc.c_in, self.odefunc.c_out)

    def set_x0(self, x0):
        # detach so the stored initial state never feeds gradients back
        self.odefunc.x0 = x0.clone().detach()

    def set_adj(self, adj):
        self.odefunc.adj = adj

    def forward(self, x, t):
        self.integration_time = torch.tensor([0, t]).float().type_as(x)

        solver = torchdiffeq.odeint_adjoint if self.adjoint else torchdiffeq.odeint
        out = solver(self.odefunc, x, self.integration_time,
                     rtol=self.rtol, atol=self.atol,
                     method=self.method, options=self.solver_opt)

        # harvest the states recorded by the func during integration,
        # resetting its buffer for the next call
        states = self.odefunc.out
        self.odefunc.out = []
        states.append(out[-1])
        return self.mlp(torch.cat(states, dim=1))


class CGP(nn.Module):
    '''Continuous graph propagation (CGP).

    Wraps a CGPFunc in an ODE block. The integration horizon `time` and the
    fixed solver step size determine the estimated number of function
    evaluations, i.e. the effective propagation depth.

    Args:
        cin, cout: input / output channel counts.
        alpha: retention coefficient forwarded to CGPFunc.
        method: torchdiffeq solver name; must be a fixed-step solver
            (adaptive solvers raise ValueError).
        solver_opt: torchdiffeq options dict; defaults to
            {'step_size': 0.25, 'perturb': False}.
        time: integration horizon.
        rtol, atol: solver tolerances.
        adjoint: use the adjoint method for backprop.
    '''
    def __init__(self, cin, cout, alpha=0.7, method='rk4',
        solver_opt: dict = None, time=1.0,
        rtol=1e-4, atol=1e-3, adjoint=False
    ):
        super(CGP, self).__init__()
        # Bug fix: the default used to be a mutable dict literal shared by
        # every instance (classic mutable-default pitfall); build a fresh
        # dict per instance instead.
        if solver_opt is None:
            solver_opt = {'step_size': 0.25, 'perturb': False}
        self.c_in = cin
        self.c_out = cout
        self.alpha = alpha
        self.integration_time = time

        if not is_solver_adaptive(method):
            # fixed-step solvers perform a known number of evaluations per step
            step_size = solver_opt["step_size"]
            self.estimated_nfe = round(self.integration_time / step_size)
            self.estimated_nfe *= fixedSolversNFEfactor[method]
        else:
            # Note: is_solver_adaptive raises NotImplementedError for unknown
            # names, so reaching here means a recognized-but-unsupported
            # adaptive solver.
            raise ValueError("Oops! The CGP solver is not available yet.")

        self.CGPODE = CGPODEBlock(CGPFunc(self.c_in, self.c_out, self.alpha),
                                  method, solver_opt, rtol, atol, adjoint,
                                  self.estimated_nfe)

    def forward(self, x, adj):
        """Propagate x (batch, c_in, nodes, seq_len) over adjacency adj."""
        self.CGPODE.set_x0(x)
        self.CGPODE.set_adj(adj)
        return self.CGPODE(x, self.integration_time)


# temporal_recur, mixprop
# from layer import dilated_inception, CGP, graph_constructor

def is_solver_adaptive(method: str):
    '''
    Return True for adaptive-step torchdiffeq solvers, False for fixed-step
    ones; raise NotImplementedError for unrecognized solver names.
    confer: https://github.com/rtqichen/torchdiffeq#list-of-ode-solvers
    '''
    adaptive = {"dopri8", "dopri5", "bosh3", "adaptive_heun"}
    fixed = {"euler", "midpoint", "rk4", "explicit_adams", "implicit_adams"}
    if method in adaptive:
        return True
    if method in fixed:
        return False
    raise NotImplementedError


class ODEFunc(nn.Module):
    """Wraps a spatio-temporal network as ODE dynamics: dx/dt = stnet(x)."""

    def __init__(self, stnet):
        super(ODEFunc, self).__init__()
        self.stnet = stnet
        self.nfe = 0  # function-evaluation counter (proxy for propagation depth)

    def forward(self, t, x):
        self.nfe += 1
        return self.stnet(x)


class ODEBlock(nn.Module):
    """Integrates an ODEFunc over [0, t] and returns the final state only."""

    def __init__(self, odefunc:ODEFunc, method:str, solver_opt:dict, rtol:float, atol:float, adjoint:bool=False):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.method = method
        self.adjoint = adjoint
        self.solver_opt = solver_opt
        self.atol = atol
        self.rtol = rtol

    def forward(self, x, t):
        self.integration_time = torch.tensor([0, t]).float().type_as(x)
        solver = torchdiffeq.odeint_adjoint if self.adjoint else torchdiffeq.odeint
        trajectory = solver(self.odefunc, x, self.integration_time,
                            rtol=self.rtol, atol=self.atol,
                            method=self.method, options=self.solver_opt)
        # only the state at the end of the horizon is needed
        return trajectory[-1]


class STBlock(nn.Module):
    '''
    Spatio-temporal block used as the ODE dynamics of MTGODE: a gated
    dilated-inception temporal convolution followed by continuous graph
    propagation over the learned graph and its transpose.

    MTGNN residual_channels & conv_channels & skip_channels  ->  same hidden_channels
    '''
    def __init__(self, receptive_field, dilation, 
        hidden_channels, dropout, method, solver_opt, time, alpha,
        rtol, atol, adjoint, kernel_set, #, step_size, perturb,
    ):
        super(STBlock, self).__init__()
        self.receptive_field = receptive_field
        # current temporal window; shrinks as the dilated convs consume steps
        self.intermediate_seq_len = receptive_field
        # adjacency matrix, injected via setGraph() before each integration
        self.graph = None
        self.dropout = dropout
        # grows by `dilation` after every forward call; reset via setIntermediate()
        self.new_dilation = 1
        self.dilation_factor = dilation
        self.filter_conv = dilated_inception(
            hidden_channels,
            hidden_channels,
            kernel_set=kernel_set,
            dilation_factor=self.new_dilation,
        )
        self.gate_conv = dilated_inception(
            hidden_channels,
            hidden_channels,
            kernel_set=kernel_set,
            dilation_factor=self.new_dilation,
        )
        # channel count actually produced by the inception convs
        # (cout is split per branch, then the branches are concatenated)
        intermediate_hc = self.gate_conv.cout * len(kernel_set)
        self.gconv_1 = CGP(intermediate_hc, hidden_channels, alpha=alpha,
                           method=method, solver_opt=solver_opt, time=time, 
                           rtol=rtol, atol=atol, adjoint=adjoint)
        self.gconv_2 = CGP(intermediate_hc, hidden_channels, alpha=alpha,
                           method=method, solver_opt=solver_opt, time=time, 
                           rtol=rtol, atol=atol, adjoint=adjoint)#, step_size=step_size, perturb=perturb

    def forward(self, x):
        # NOTE(review): this forward is stateful -- each call increases the
        # dilation and shrinks the window; the outer solver must reset the
        # state via setIntermediate(1) between integrations (MTGODE does so).
        x = x[..., -self.intermediate_seq_len:]
        # re-point the existing convs at the current dilation without rebuilding them
        for tconv in self.filter_conv.tconv:
            tconv.dilation = (1, self.new_dilation)
        for tconv in self.gate_conv.tconv:
            tconv.dilation = (1, self.new_dilation)

        # gated temporal convolution: tanh(filter) * sigmoid(gate)
        filter = self.filter_conv(x)
        filter = torch.tanh(filter)
        gate = self.gate_conv(x)
        gate = torch.sigmoid(gate)
        x = filter * gate
        # print("filter * gate --  x.size(): ", x.size())  #  torch.Size([7, 27, 20, 45])

        self.new_dilation *= self.dilation_factor
        self.intermediate_seq_len = x.size(3)

        x = F.dropout(x, self.dropout, training=self.training)

        # propagate over the learned graph and its transpose (both edge directions)
        x = self.gconv_1(x, self.graph) + self.gconv_2(x, self.graph.transpose(1, 0))

        # left-pad back to the full receptive field so the ODE state keeps a fixed shape
        x = nn.functional.pad(x, (self.receptive_field - x.size(3), 0))

        return x

    def setGraph(self, graph):
        # adjacency used by both graph propagations in forward()
        self.graph = graph

    def setIntermediate(self, dilation):
        # reset the per-integration temporal state (see forward)
        self.new_dilation = dilation
        self.intermediate_seq_len = self.receptive_field

# Function evaluations per integration step for each fixed-step torchdiffeq solver.
fixedSolversNFEfactor = {
    "euler": 1,
    "midpoint": 2,
    "rk4": 4,
    "explicit_adams": 4,
    "implicit_adams": 4,
}

# First 46 primes; used as kernel-width candidates for dilated_inception.
# print([num for num in range(200) if is_prime(num)])
prime_list = [
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53,
    59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113,
    127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181,
    191, 193, 197, 199,
]

class MTGODE(nn.Module):
    '''
    ::nfe:: the number of function evaluations, which is equivalent 
    to describe how many times feature propagation is executed, 
    a.k.a. the propagation depth K in the discrete formulation.  \n

    ::solver_opt:: https://github.com/rtqichen/torchdiffeq/blob/master/FURTHER_DOCUMENTATION.md#solver-options \n

    _1, (TCN) temporal :: CTA nfe := time_1 // step_size_1 * factor \n
    _2, (GCN) spatial ::: CGP nfe := time_2 // step_size_2 * factor
    '''
    # NOTE(review): kernel_set, solver_opt_1 and solver_opt_2 defaults are
    # mutable objects shared across instances; they look read-only here, but
    # confirm before mutating them downstream.
    def __init__(self, buildA_true, num_nodes, #device, 
        dilation_exponential=1, kernel_set=prime_list[:10],#[2,3,5,7,11,13,17,19]
        predefined_A=None, static_feat=None, dropout=0,
        subgraph_size=20, node_dim=40, conv_channels=2**5, end_channels=2**6,
        seq_length=12, in_dim=2, out_dim=12, tanhalpha=3, 
        method_1='euler', time_1=1.2, solver_opt_1={"step_size":0.4, "perturb":False},
        method_2='euler', time_2=1.0, solver_opt_2={"step_size":0.25, "perturb":False}, 
        alpha=1.0, rtol=1e-4, atol=1e-3, adjoint=True,
        ln_affine=True):#, perturb=False, step_size_1=0.4, step_size_2=0.25

        super().__init__()

        self.integration_time = time_1

        # Estimate the temporal propagation depth: number of fixed-size steps
        # times the function evaluations per step of the chosen solver.
        if not is_solver_adaptive(method_1):# method_1 in ['euler', 'rk4']
            step_size = solver_opt_1["step_size"]
            self.estimated_nfe = round(self.integration_time / step_size) # round-half-to-even (banker's rounding)
            self.estimated_nfe *= fixedSolversNFEfactor[method_1]
            # round(self.integration_time / (step_size / 4.0))
        elif is_solver_adaptive(method_1):
            # self.estimated_nfe = 3
            raise ValueError("Oops! Temporal ODE solver is not available yet.")
        else:
            # NOTE(review): unreachable -- is_solver_adaptive raises for unknown names
            raise ValueError("Oops! Temporal ODE solver is invaild.")
        # estimated_nfe is "equivalent" to # MTGNN layers

        self.buildA_true = buildA_true
        self.num_nodes = num_nodes
        self.dropout = dropout
        self.predefined_A = predefined_A
        self.seq_length = seq_length
        self.ln_affine = ln_affine
        self.adjoint = adjoint

        # 1x1 conv lifting the raw input features to conv_channels
        self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=conv_channels, kernel_size=(1, 1))

        # learned sparse graph over the nodes (used when buildA_true)
        self.gc = graph_constructor(num_nodes, subgraph_size, node_dim, alpha=tanhalpha, static_feat=static_feat)#, device
        self.idx = torch.arange(self.num_nodes)#.to(device)

        ### _set_receptive_field
        # Receptive field of estimated_nfe stacked dilated convolutions with
        # the widest kernel in the set.
        max_kernel_size = kernel_set[-1]
        if dilation_exponential > 1:
            self.receptive_field = int(1 + (max_kernel_size - 1) * (dilation_exponential**self.estimated_nfe - 1) / (dilation_exponential - 1))
        else:
            self.receptive_field = self.estimated_nfe * (max_kernel_size - 1) + 1

        msg = f"MTGODE is built with receptive_field size '{self.receptive_field}' and the input seq_length '{self.seq_length}'. "
        if self.receptive_field >= self.seq_length:
            print(msg, '\nZero-padding is applied for lager receptive_field.')
        else:
            print(msg, '\nInput is trimmed for smaller receptive_field.')
            # raise NotImplementedError(f"MTGODE requires that the receptive_field size \
            #     '{self.receptive_field}' is not less than the input seq_length '{self.seq_length}'")

        if ln_affine:
            # learnable per-channel/per-node affine applied after layer norm
            self.affine_weight = nn.Parameter(torch.Tensor(*(conv_channels, self.num_nodes)))  # C*H
            self.affine_bias = nn.Parameter(torch.Tensor(*(conv_channels, self.num_nodes)))  # C*H

        # Outer (temporal) ODE whose dynamics are an STBlock; the STBlock in
        # turn runs the inner (spatial) CGP ODEs.
        self.ODE = ODEBlock(ODEFunc(STBlock(receptive_field=self.receptive_field, dilation=dilation_exponential,
                                            hidden_channels=conv_channels, dropout=self.dropout, method=method_2, 
                                            solver_opt=solver_opt_2, time=time_2, alpha=alpha, rtol=rtol, atol=atol,
                                            adjoint=False, kernel_set=kernel_set,#, step_size=step_size_2, perturb=perturb
                )), method_1, solver_opt_1, rtol, atol, adjoint)#, step_size_1, perturb

        # self.end_conv_0 = nn.Conv2d(in_channels=conv_channels, out_channels=end_channels//2, kernel_size=(1, 1), bias=True)
        # self.end_conv_1 = nn.Conv2d(in_channels=end_channels//2, out_channels=end_channels, kernel_size=(1, 1), bias=True)
        self.end_conv_1 = nn.Conv2d(in_channels=conv_channels, out_channels=end_channels, kernel_size=(1, 1), bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=end_channels, out_channels=out_dim, kernel_size=(1, 1), bias=True)

        if ln_affine:
            self.reset_parameters()

    def reset_parameters(self):
        # identity affine: weight=1, bias=0
        init.ones_(self.affine_weight)
        init.zeros_(self.affine_bias)

    def forward(self, input, idx=None):
        # input: (batch, in_dim, num_nodes, seq_len) -> (batch, out_dim, num_nodes, 1)
        seq_len = input.size(3)
        assert seq_len == self.seq_length, 'input sequence length not equal to preset sequence length'

        # left-pad or trim the time axis to exactly the receptive field
        if self.seq_length < self.receptive_field:
            input = nn.functional.pad(input, (self.receptive_field-self.seq_length, 0))
        elif self.seq_length > self.receptive_field:
            input = input[..., -self.receptive_field:]

        # adjacency: learned on the fly, or a fixed predefined matrix
        if self.buildA_true:
            if idx is None:
                adp = self.gc(self.idx.to(input.device))
            else:
                adp = self.gc(idx.to(input.device))
        else:
            adp = self.predefined_A.to(input.device)

        x = self.start_conv(input)
        # print("x.size(): ", x.size())

        # reset the STBlock's dilation state before (and after) integrating;
        # its forward mutates that state on every solver evaluation
        if self.adjoint:
            self.ODE.odefunc.stnet.setIntermediate(dilation=1)
        self.ODE.odefunc.stnet.setGraph(adp)
        x = self.ODE(x, self.integration_time)
        self.ODE.odefunc.stnet.setIntermediate(dilation=1)

        # keep only the final time step, then normalize
        x = x[..., -1:]
        x = F.layer_norm(x, tuple(x.shape[1:]), weight=None, bias=None, eps=1e-5)

        if self.ln_affine:
            if idx is None:
                x = torch.add(torch.mul(x, self.affine_weight[:, self.idx].unsqueeze(-1)), self.affine_bias[:, self.idx].unsqueeze(-1))  # C*H
            else:
                x = torch.add(torch.mul(x, self.affine_weight[:, idx].unsqueeze(-1)), self.affine_bias[:, idx].unsqueeze(-1))  # C*H

        # x = F.relu(self.end_conv_0(x))
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)

        return x
